package top.codef.secure.login;

import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.mapping;
import static java.util.stream.Collectors.toList;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.security.authentication.AccountExpiredException;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.CredentialsExpiredException;
import org.springframework.security.authentication.DisabledException;
import org.springframework.security.authentication.InternalAuthenticationServiceException;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsChecker;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.util.Assert;

import top.codef.secure.exceptions.CerberusAuthenticationException;
import top.codef.secure.login.interfaces.CerberusAuthenticationChecker;
import top.codef.secure.login.interfaces.CerberusUserDetailsService;
import top.codef.secure.properties.CerberusCustomLoginProperties;

/**
 * {@link AuthenticationProvider} for {@link CerberusAuthenticationToken}s.
 * <p>
 * The token's {@code optional} value selects the {@link CerberusUserDetailsService} used to load the
 * user. It also selects the list of {@link CerberusAuthenticationChecker}s that must all accept that
 * user. The loaded user is additionally run through pre- and post-authentication
 * {@link UserDetailsChecker}s.
 */
public class CerberusAuthenticationProvider implements AuthenticationProvider {

	private static final Log logger = LogFactory.getLog(CerberusAuthenticationProvider.class);

	/** Checkers keyed by every name each checker declares via {@code names()}. */
	private final Map<String, List<CerberusAuthenticationChecker>> checkers;

	/** User details services keyed by name; when several declare the same name, the first one wins. */
	private final Map<String, CerberusUserDetailsService> detailsServiceMap;

	private final CerberusCustomLoginProperties cerberusCustomLoginProperties;

	/** Run on the loaded user before the Cerberus checkers. */
	private final UserDetailsChecker preAuthenticationChecks;

	/** Run on the loaded user after all Cerberus checkers have accepted it. */
	private final UserDetailsChecker postAuthenticationChecks;

	/**
	 * Creates a provider that uses the default pre/post authentication checks.
	 *
	 * @param checkers                      authentication checkers; {@code null} is treated as none
	 * @param userDetailsServices           user details services; {@code null} is treated as none
	 * @param cerberusCustomLoginProperties login properties exposed via the getter
	 */
	public CerberusAuthenticationProvider(List<CerberusAuthenticationChecker> checkers,
			List<CerberusUserDetailsService> userDetailsServices,
			CerberusCustomLoginProperties cerberusCustomLoginProperties) {
		this(checkers, userDetailsServices, cerberusCustomLoginProperties, null, null);
	}

	/**
	 * Creates a provider with custom pre/post authentication checks.
	 *
	 * @param checkers                      authentication checkers; {@code null} is treated as none
	 * @param userDetailsServices           user details services; {@code null} is treated as none
	 * @param cerberusCustomLoginProperties login properties exposed via the getter
	 * @param preAuthenticationChecks       checks run before the Cerberus checkers; {@code null} selects
	 *                                      the default (locked / disabled / expired account)
	 * @param postAuthenticationChecks      checks run after the Cerberus checkers; {@code null} selects
	 *                                      the default (expired credentials)
	 */
	public CerberusAuthenticationProvider(List<CerberusAuthenticationChecker> checkers,
			List<CerberusUserDetailsService> userDetailsServices,
			CerberusCustomLoginProperties cerberusCustomLoginProperties, UserDetailsChecker preAuthenticationChecks,
			UserDetailsChecker postAuthenticationChecks) {
		this.checkers = groupCheckersByName(checkers);
		this.detailsServiceMap = indexServicesByName(userDetailsServices);
		this.cerberusCustomLoginProperties = cerberusCustomLoginProperties;
		this.preAuthenticationChecks = preAuthenticationChecks != null ? preAuthenticationChecks
				: new DefaultPreAuthenticationChecks();
		this.postAuthenticationChecks = postAuthenticationChecks != null ? postAuthenticationChecks
				: new DefaultPostAuthenticationChecks();
	}

	/** Groups checkers under each name they declare; a {@code null} list yields an empty map. */
	private static Map<String, List<CerberusAuthenticationChecker>> groupCheckersByName(
			List<CerberusAuthenticationChecker> checkers) {
		if (checkers == null) {
			return Collections.emptyMap();
		}
		return checkers.stream().flatMap(x -> x.names().stream().map(y -> new CheckerForName(y, x)))
				.collect(groupingBy(CheckerForName::getName, mapping(CheckerForName::getChecker, toList())));
	}

	/** Indexes services by each name they declare, keeping the first service registered per name. */
	private static Map<String, CerberusUserDetailsService> indexServicesByName(
			List<CerberusUserDetailsService> userDetailsServices) {
		Map<String, CerberusUserDetailsService> byName = new HashMap<String, CerberusUserDetailsService>();
		if (userDetailsServices != null) {
			userDetailsServices.forEach(x -> x.names().forEach(y -> byName.putIfAbsent(y, x)));
		}
		return byName;
	}

	/** (name, checker) pair used while flattening checkers into the name-keyed map. */
	static class CheckerForName {

		final String name;

		final CerberusAuthenticationChecker checker;

		public CheckerForName(String name, CerberusAuthenticationChecker checker) {
			this.name = name;
			this.checker = checker;
		}

		public String getName() {
			return name;
		}

		public CerberusAuthenticationChecker getChecker() {
			return checker;
		}

	}

	public Map<String, List<CerberusAuthenticationChecker>> getCheckers() {
		return checkers;
	}

	public Map<String, CerberusUserDetailsService> getDetailsServiceMap() {
		return detailsServiceMap;
	}

	public CerberusCustomLoginProperties getCerberusCustomLoginProperties() {
		return cerberusCustomLoginProperties;
	}

	/**
	 * Authenticates a {@link CerberusAuthenticationToken}.
	 * <p>
	 * Steps, in order:
	 * <ol>
	 * <li>Load the user through the service registered for the token's {@code optional}.</li>
	 * <li>Run the pre-authentication checks.</li>
	 * <li>Run every checker registered for that {@code optional}.</li>
	 * <li>Run the post-authentication checks.</li>
	 * </ol>
	 *
	 * @param authentication must be a {@link CerberusAuthenticationToken}; a non-null principal is
	 *                       expected to be a {@code String} (it is cast)
	 * @return a new authenticated {@link CerberusAuthenticationToken} wrapping the loaded user
	 * @throws CerberusAuthenticationException when no checker is registered for the {@code optional}
	 *                                         or a checker rejects the user (propagated unchanged)
	 * @throws BadCredentialsException         for every other failure, including user lookup errors
	 *                                         and account status failures, which are logged and
	 *                                         wrapped as the cause
	 */
	@Override
	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
		Assert.isInstanceOf(CerberusAuthenticationToken.class, authentication,
				() -> "Only CerberusAuthenticationToken is supported");
		// Placeholder username when the token carries no principal.
		String principal = authentication.getPrincipal() == null ? "NONE_PROVIDED"
				: (String) authentication.getPrincipal();
		CerberusAuthenticationToken cerberusAuthenticationToken = (CerberusAuthenticationToken) authentication;
		try {
			String optional = cerberusAuthenticationToken.getOptional();
			UserDetails userDetails = retrieveUser(principal, optional, cerberusAuthenticationToken);
			Assert.notNull(userDetails, "user is null");
			preAuthenticationChecks.check(userDetails);
			List<CerberusAuthenticationChecker> suitCheckers = checkers.get(optional);
			if (suitCheckers == null || suitCheckers.isEmpty()) {
				throw new CerberusAuthenticationException("check error , no suitable checkers: " + optional);
			}
			for (CerberusAuthenticationChecker cerberusAuthenticationChecker : suitCheckers) {
				if (!cerberusAuthenticationChecker.check(userDetails, cerberusAuthenticationToken)) {
					throw new CerberusAuthenticationException("cerberus check error！" + userDetails.getUsername());
				}
			}
			postAuthenticationChecks.check(userDetails);
			return genericNewSuccessAuthentication(userDetails, cerberusAuthenticationToken);
		} catch (CerberusAuthenticationException e) {
			throw e;
		} catch (Exception e) {
			logger.error("authenticationError", e);
			throw new BadCredentialsException("用户验证错误", e);
		}
	}

	/** Builds the authenticated token returned on success. */
	private Authentication genericNewSuccessAuthentication(UserDetails userDetails,
			CerberusAuthenticationToken authenticationToken) {
		return new CerberusAuthenticationToken(userDetails, authenticationToken);
	}

	/**
	 * Loads the user through the {@link CerberusUserDetailsService} registered under {@code optional}.
	 *
	 * @throws InternalAuthenticationServiceException when no service is registered for
	 *                                                {@code optional}, when the service returns
	 *                                                {@code null}, or wrapping any unexpected
	 *                                                exception
	 */
	private UserDetails retrieveUser(String username, String optional, CerberusAuthenticationToken token)
			throws AuthenticationException {
		try {
			CerberusUserDetailsService userDetailsService = detailsServiceMap.get(optional);
			if (userDetailsService == null) {
				logger.warn("userDetailService is not found , check your code again");
				throw new InternalAuthenticationServiceException("CerberusUserDetailsService is not found");
			}
			UserDetails loadedUser = userDetailsService.loadUserByUsername(username, token.getAdditionalParams());
			if (loadedUser == null) {
				throw new InternalAuthenticationServiceException(
						"UserDetailsService returned null, which is an interface contract violation");
			}
			return loadedUser;
		} catch (UsernameNotFoundException | InternalAuthenticationServiceException
				| CerberusAuthenticationException ex) {
			throw ex;
		} catch (Exception ex) {
			throw new InternalAuthenticationServiceException(ex.getMessage(), ex);
		}
	}

	@Override
	public boolean supports(Class<?> authentication) {
		return CerberusAuthenticationToken.class.isAssignableFrom(authentication);
	}

	/** Default pre-check: rejects locked, disabled, or expired accounts, in that order. */
	private static class DefaultPreAuthenticationChecks implements UserDetailsChecker {
		@Override
		public void check(UserDetails user) {
			if (!user.isAccountNonLocked()) {
				logger.debug("User account is locked");
				throw new LockedException("User account is locked");
			}
			if (!user.isEnabled()) {
				logger.debug("User account is disabled");
				throw new DisabledException("User is disabled");
			}
			if (!user.isAccountNonExpired()) {
				logger.debug("User account is expired");
				throw new AccountExpiredException("User account has expired");
			}
		}
	}

	/** Default post-check: rejects users whose credentials have expired. */
	private static class DefaultPostAuthenticationChecks implements UserDetailsChecker {
		@Override
		public void check(UserDetails user) {
			if (!user.isCredentialsNonExpired()) {
				logger.debug("User account credentials have expired");
				throw new CredentialsExpiredException("User credentials have expired");
			}
		}
	}

}
